import numpy as np
import struct
import os
import copy

from scipy.ndimage.filters import gaussian_filter


class StructReader:
	"""Cursor over a bytes buffer that decodes consecutive struct-formatted fields.

	Every Read() decodes at the current offset and then moves the offset past
	exactly the bytes that format occupies.
	"""
	def __init__(self, data : bytes) -> None:
		self.position = 0 #type: int
		self.data = data #type: bytes

	def Read(self, vars : str):
		"""Decode the struct format ``vars`` at ``position``, advance, return the tuple."""
		fields = struct.unpack_from(vars, self.data, self.position)
		consumed = struct.calcsize(vars)
		self.position = self.position + consumed
		return fields



class TransientImage:
	"""A binned image cube read from / written to the "TI" file formats.

	Loading supports TI00 / TI01 (three uint32 dimensions, tMin/tMax, a float32
	cube, trailing UTF-8 properties) and TI04 with pixel mode 10 (a header, the
	float32 cube, then a pixel interpretation block with the u/v resolution, the
	topLeft/topRight/bottomLeft/bottomRight corners and the laser position).
	After loading, ``data`` has shape (vResolution, uResolution, numBins).
	"""
	# 4-byte magic numbers: ASCII "TI00", "TI01" and "TI04".
	TiVersion00 = [84, 73, 48, 48]
	TiVersion01 = [84, 73, 48, 49]
	TiVersion04 = [84, 73, 48, 52]
	
	def __init__(self, filename : str = None) -> None:
		"""Create an empty image, or load ``filename`` right away when given."""
		if filename is not None:
			self.LoadFile(filename)
	
	def Reload(self) -> None:
		"""Re-read the file most recently loaded through LoadFile().

		Raises:
			ValueError: if no file has been loaded into this instance yet.
		"""
		filename = getattr(self, "filenameCache", None)
		if filename is None:
			raise ValueError("Reload() called before any file was loaded")
		self.LoadFile(filename)
	
	def LoadFile(self, filename : str) -> None:
		"""Load ``filename``, dispatching on its 4-byte magic number.

		Raises:
			Exception: if the magic number matches no supported version.
		"""
		with open(filename, "rb") as f:
			versionHeader = np.fromfile(f, dtype=np.int8, count=4)
			
		if np.array_equal(TransientImage.TiVersion00, versionHeader) or np.array_equal(TransientImage.TiVersion01, versionHeader):
			self.LoadVersion01(filename)
		elif np.array_equal(TransientImage.TiVersion04, versionHeader):
			self.LoadVersion04(filename)
		else:
			raise Exception("Wrong header: ", versionHeader)
		# Remember the source so Reload() works (it was never recorded before).
		self.filenameCache = filename

	def LoadVersion01(self, filename : str) -> None:
		"""Load a TI00 or TI01 file into this instance.

		Both versions store three uint32 dimensions (TI00: u, v, bins; TI01:
		bins, u, v), then tMin/tMax as float32, the float32 cube and the UTF-8
		image properties up to end of file.

		Raises:
			Exception: if the magic number is neither TI00 nor TI01.
		"""
		with open(filename, "rb") as f:
			self.fileVersion = 1
			header = np.fromfile(f, dtype=np.int8, count=4)
			if np.array_equal(TransientImage.TiVersion00, header):
				print("Loading file version 00")
				self.uResolution, self.vResolution, self.numBins = np.fromfile(f, dtype=np.uint32, count=3)
			elif np.array_equal(TransientImage.TiVersion01, header):
				print("Loading file version 01")
				self.numBins, self.uResolution, self.vResolution = np.fromfile(f, dtype=np.uint32, count=3)
			else:
				raise Exception("Wrong header: ", header)
			self.tMin, self.tMax = np.fromfile(f, dtype=np.float32, count=2)
			self.tDelta = (self.tMax - self.tMin) / float(self.numBins)
			# Multiply as Python ints: a product of uint32 scalars can silently wrap.
			count = int(self.uResolution) * int(self.vResolution) * int(self.numBins)
			self.data = np.fromfile(f, dtype=np.float32, count=count)
			self.data = self.data.reshape((self.vResolution, self.uResolution, self.numBins))
			
			# image properties
			self.imageProperties = bytearray(f.read()).decode('utf-8')
			
	def LoadVersion04(self, filename : str) -> None:
		"""Load a TI04 file into this instance; only pixel mode 10 is supported.

		Raises:
			Exception: if the file's pixel mode is not 10.
		"""
		print("Loading file version 04")
		with open(filename, "rb") as f:
			self.fileVersion = 4
			# Skip the magic number; LoadFile() has already validated it.
			np.fromfile(f, dtype=np.int8, count=4)
			self.pixelMode = np.fromfile(f, dtype=np.uint32, count=1)[0]
			self.numPixels = np.fromfile(f, dtype=np.uint32, count=1)[0]
			self.numBins = np.fromfile(f, dtype=np.uint32, count=1)[0]
			self.tMin = np.fromfile(f, dtype=np.float32, count=1)[0]
			self.tDelta = np.fromfile(f, dtype=np.float32, count=1)[0]
			self.pixelInterpretationBlockSize = np.fromfile(f, dtype=np.uint32, count=1)[0]
			
			self.tMax = self.tMin + self.tDelta*self.numBins
			
			if(10 != self.pixelMode):
				raise Exception("currently only mode 10 is supported")
			
			# pixel data; Python ints avoid uint32 wrap-around in the element count
			self.data = np.fromfile(f, dtype=np.float32, count=int(self.numPixels) * int(self.numBins))
			
			# pixel interpretation block
			self.uResolution = np.fromfile(f, dtype=np.uint32, count=1)[0]
			self.vResolution = np.fromfile(f, dtype=np.uint32, count=1)[0]
			self.topLeft = np.fromfile(f, dtype=np.float32, count=3)
			self.topRight = np.fromfile(f, dtype=np.float32, count=3)
			self.bottomLeft = np.fromfile(f, dtype=np.float32, count=3)
			self.bottomRight = np.fromfile(f, dtype=np.float32, count=3)
			self.laserPosition = np.fromfile(f, dtype=np.float32, count=3)
			
			self.data = self.data.reshape((self.vResolution, self.uResolution, self.numBins))
			
			# image properties
			self.imageProperties = bytearray(f.read()).decode('utf-8')

	def _PackedData(self, expectedCount : int) -> bytes:
		"""Return ``data`` as native-endian float32 bytes in C (row-major) order.

		Produces the same bytes as struct-packing ``data.flatten()`` with "f",
		without turning every sample into a separate Python float argument.

		Raises:
			ValueError: if ``data`` does not hold exactly ``expectedCount`` values.
		"""
		samples = np.ascontiguousarray(self.data, dtype=np.float32)
		if samples.size != expectedCount:
			raise ValueError("data holds {} values but the header describes {}".format(samples.size, expectedCount))
		return samples.tobytes()
	
	def SaveFileVersion01(self, filename : str) -> None:
		"""Write the image in TI01 layout: numBins, u, v, tMin, tMax, cube, properties."""
		# Serialize first so a size mismatch does not leave a truncated file behind.
		payload = self._PackedData(int(self.uResolution) * int(self.vResolution) * int(self.numBins))
		with open(filename, 'wb') as file:
			file.write(struct.pack("4s", bytes("TI01", "utf-8")))
			file.write(struct.pack("III", self.numBins, self.uResolution, self.vResolution))
			file.write(struct.pack("ff", self.tMin, self.tMax))
			file.write(payload)
			file.write(bytes(self.imageProperties, "utf-8"))

	def SaveFileVersion04(self, filename : str) -> None:
		"""Write the image in TI04 layout with pixel mode 10.

		Raises:
			Exception: if pixelMode is not 10 or pixelInterpretationBlockSize is not 68.
		"""
		if(10 != self.pixelMode):
			raise Exception("currently only mode 10 is supported")
		if(68 != self.pixelInterpretationBlockSize):
			raise Exception("pixelInterpretationBlockSize does not indicate mode 10. Currently only mode 10 is supported")
		
		payload = self._PackedData(int(self.vResolution) * int(self.uResolution) * int(self.numBins))
		# Zeros used to be written here unconditionally, discarding the loaded
		# laser position; fall back to zeros only when none is present.
		laserPosition = getattr(self, "laserPosition", (0.0, 0.0, 0.0))
		with open(filename, 'wb') as file:
			# file header
			file.write(struct.pack("4s", bytes("TI04", "utf-8")))
			file.write(struct.pack("IIIffI", self.pixelMode, self.numPixels, self.numBins,
						  self.tMin, self.tDelta, self.pixelInterpretationBlockSize))
			# pixel data
			file.write(payload)
			
			# pixel interpretation data: 2 uint32 + 5 * 3 float32 = 68 bytes
			file.write(struct.pack("II3f3f3f3f3f", self.uResolution, self.vResolution,
						  *self.topLeft, *self.topRight, *self.bottomLeft, *self.bottomRight, *laserPosition))
			
			file.write(bytes(self.imageProperties, "utf-8"))
	
	def SaveFile(self, filename : str) -> None:
		"""Save in the default format, TI04 (see SaveFileVersion04)."""
		self.SaveFileVersion04(filename)
		
	def PrintInfo(self) -> None:
		"""Print dimensions, range, NaN-ignoring min/max and NaN/Inf counts of ``data``."""
		print("Dimensions: ", self.numBins, self.uResolution, self.vResolution)
		print("Range: ", self.tMin, self.tMax)
		print("Min: ", np.nanmin(self.data))
		print("Max: ", np.nanmax(self.data))
	
		print("Nans: ", np.sum(np.isnan(self.data)))
		print("Infs: ", np.sum(np.isinf(self.data)))

	def TimeIntegrated(self) -> np.ndarray:
		"""Return the per-pixel mean over the last (bin) axis of ``data``."""
		return np.sum(self.data, axis=2) / float(self.data.shape[2])
	
def SliceMeter(input : TransientImage, begin : float, end : float) -> TransientImage:
	"""Return a deep copy of ``input`` cut down to the bins between ``begin`` and ``end``.

	``begin``/``end`` use the same units as tMin/tMax. The kept bins are
	binMin..binMax-1, where each index is ``(bound - tMin) / binWidth`` truncated
	to an int and clamped to [0, numBins]. Out-of-range bounds therefore select
	only the overlapping bins (possibly none), instead of wrapping around the way
	negative slice indices would. tMin, tMax and numBins of the result describe
	the kept bins; ``input`` itself is not modified.
	"""
	result = copy.deepcopy(input)
	offset = result.tMin
	# TransientImage stores the bin count as numBins; the old code read a
	# never-defined ``resT`` attribute and always raised AttributeError.
	numBins = int(result.numBins)
	binWidth = (result.tMax - offset) / numBins
	binMin = min(max(int((begin - offset) / binWidth), 0), numBins)
	binMax = min(max(int((end - offset) / binWidth), binMin), numBins)
	
	result.data = result.data[:, :, binMin:binMax]
	result.tMin = offset + binMin*binWidth
	result.tMax = offset + binMax*binWidth
	# Keep the header consistent so SaveFile() writes the sliced bin count.
	result.numBins = result.data.shape[2]
	# Legacy attribute name, kept for callers that still read it.
	result.resT = result.numBins
	
	return result
	
def Blur(input : TransientImage, filterSize : float) -> TransientImage:
	"""Return a Gaussian-smoothed copy of ``input``; the argument is left untouched.

	``filterSize`` is forwarded as the ``sigma`` of scipy's gaussian_filter. As
	a single scalar it smooths along every axis of ``data``.
	NOTE(review): for the (v, u, bins) cubes produced by TransientImage this
	also blurs across bins, not only across pixels -- confirm that is intended.
	"""
	smoothed = gaussian_filter(input.data, filterSize)
	blurred = copy.deepcopy(input)
	blurred.data = smoothed
	return blurred
